package limiter

import (
	"context"
	"sync"
	"sync/atomic"
	"time"

	"golang.org/x/time/rate"
)

var (
	// Every converts a minimum interval between events into a Limit; it is
	// re-exported from golang.org/x/time/rate.
	Every = rate.Every

	// globalLimiters is the process-wide registry used by NewTokenLimiter.
	globalLimiters = &TokenLimiters{limiters: new(sync.Map)}

	// once ensures the idle-limiter cleaner goroutine is started only once.
	once sync.Once
)

// Limit is an alias of rate.Limit (maximum event frequency, expressed in
// events per second), re-exported by this package.
type Limit = rate.Limit

// Reservation is an alias of rate.Reservation; TokenLimiter.Reserve returns a
// pointer to one.
type Reservation = rate.Reservation

// TokenLimiters is a registry of per-key token-bucket limiters. Limiters that
// stay unused for more than a minute are evicted by clearLimiter.
type TokenLimiters struct {
	limiters *sync.Map // key (string) -> *TokenLimiter
}

// TokenLimiter wraps a rate.Limiter and records when it was last used so that
// idle limiters can be evicted from the registry.
type TokenLimiter struct {
	// lastGet is the time of the last use, in Unix nanoseconds. It is written
	// by callers of Allow/Wait/Reserve while the cleaner goroutine reads it, so
	// it must only be accessed through sync/atomic. It is the first field so
	// that it is 64-bit aligned on 32-bit platforms, as sync/atomic requires.
	lastGet int64
	limiter *rate.Limiter
	key     string
}

// NewTokenLimiter returns the limiter registered under key, creating one if
// none exists. r is the rate at which tokens are added to the bucket, b is the
// bucket size, and key identifies what is being limited (for example a user
// ID or an IP address). If a limiter for key already exists, r and b are
// ignored. The first call starts the background goroutine that evicts idle
// limiters.
func NewTokenLimiter(r rate.Limit, b int, key string) *TokenLimiter {
	once.Do(func() {
		go globalLimiters.clearLimiter()
	})
	return globalLimiters.getLimiter(r, b, key)
}

// touch records the current time as the limiter's last use.
func (l *TokenLimiter) touch() {
	atomic.StoreInt64(&l.lastGet, time.Now().UnixNano())
}

// idleFor reports how long before now the limiter was last used.
func (l *TokenLimiter) idleFor(now time.Time) time.Duration {
	return time.Duration(now.UnixNano() - atomic.LoadInt64(&l.lastGet))
}

// Allow marks the limiter as used and delegates to rate.Limiter.Allow.
func (l *TokenLimiter) Allow() bool {
	l.touch()
	return l.limiter.Allow()
}

// Wait marks the limiter as used and delegates to rate.Limiter.Wait, which
// blocks until a token is available or returns an error (e.g. ctx canceled).
func (l *TokenLimiter) Wait(ctx context.Context) error {
	l.touch()
	return l.limiter.Wait(ctx)
}

// Reserve marks the limiter as used and delegates to rate.Limiter.Reserve.
func (l *TokenLimiter) Reserve() *Reservation {
	l.touch()
	return l.limiter.Reserve()
}

// getLimiter returns the limiter stored under key, creating it with token
// rate r and bucket size b if absent. The returned limiter is marked as used.
//
// LoadOrStore guarantees that concurrent first callers for the same key all
// receive the same limiter. A separate Load followed by Store would let each
// caller create its own limiter, with the last one silently overwriting the
// others and multiplying the effective rate.
func (ls *TokenLimiters) getLimiter(r rate.Limit, b int, key string) *TokenLimiter {
	// Fast path: avoid allocating a new rate.Limiter when the key exists.
	if existing, ok := ls.limiters.Load(key); ok {
		l := existing.(*TokenLimiter)
		l.touch()
		return l
	}

	fresh := &TokenLimiter{limiter: rate.NewLimiter(r, b), key: key}
	// Touch before publishing so the cleaner never sees a zero timestamp.
	fresh.touch()
	actual, loaded := ls.limiters.LoadOrStore(key, fresh)
	l := actual.(*TokenLimiter)
	if loaded {
		l.touch()
	}
	return l
}

// clearLimiter runs forever, once a minute evicting every limiter that has
// not been used for more than a minute. NewTokenLimiter starts it once.
func (ls *TokenLimiters) clearLimiter() {
	const (
		sweepInterval = time.Minute
		idleTimeout   = time.Minute
	)
	ticker := time.NewTicker(sweepInterval)
	for now := range ticker.C {
		ls.limiters.Range(func(key, value interface{}) bool {
			if value.(*TokenLimiter).idleFor(now) > idleTimeout {
				// Delete cannot remove a different limiter than the one just
				// inspected: only this goroutine deletes entries, and
				// getLimiter never overwrites an existing one.
				// NOTE: a limiter used between the check and the Delete can
				// still be evicted; holders keep a working but detached copy.
				ls.limiters.Delete(key)
			}
			return true
		})
	}
}
